{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "d3d85837",
   "metadata": {
    "origin_pos": 1
   },
   "source": [
    "# The Encoder--Decoder Architecture\n",
    ":label:`sec_encoder-decoder`\n",
    "\n",
    "In general sequence-to-sequence problems\n",
    "like machine translation\n",
    "(:numref:`sec_machine_translation`),\n",
    "inputs and outputs are of varying lengths\n",
    "and their tokens are not aligned one-to-one.\n",
    "The standard approach to handling this sort of data\n",
    "is to design an *encoder--decoder* architecture (:numref:`fig_encoder_decoder`)\n",
    "consisting of two major components:\n",
    "an *encoder* that takes a variable-length sequence as input,\n",
    "and a *decoder* that acts as a conditional language model,\n",
    "taking in the encoded input\n",
    "and the leftward context of the target sequence\n",
    "and predicting the subsequent token in the target sequence.\n",
    "\n",
    "\n",
    "![The encoder--decoder architecture.](../img/encoder-decoder.svg)\n",
    ":label:`fig_encoder_decoder`\n",
    "\n",
    "Let's take machine translation from English to French as an example.\n",
    "Given an input sequence in English:\n",
    "\"They\", \"are\", \"watching\", \".\",\n",
    "this encoder--decoder architecture\n",
    "first encodes the variable-length input into a state,\n",
    "then decodes the state\n",
    "to generate the translated sequence,\n",
    "token by token, as output:\n",
    "\"Ils\", \"regardent\", \".\".\n",
    "Since the encoder--decoder architecture\n",
    "forms the basis of different sequence-to-sequence models\n",
    "in subsequent sections,\n",
    "this section will convert this architecture\n",
    "into an interface that will be implemented later.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f6ad17a2",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:29:03.809994Z",
     "iopub.status.busy": "2023-08-18T19:29:03.809520Z",
     "iopub.status.idle": "2023-08-18T19:29:06.566331Z",
     "shell.execute_reply": "2023-08-18T19:29:06.565200Z"
    },
    "origin_pos": 3,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "from torch import nn\n",
    "from d2l import torch as d2l"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "15bee8b2",
   "metadata": {
    "origin_pos": 6
   },
   "source": [
    "## (**Encoder**)\n",
    "\n",
    "In the encoder interface,\n",
    "we just specify that\n",
    "the encoder takes variable-length sequences as input `X`.\n",
    "The implementation will be provided\n",
    "by any model that inherits from this base `Encoder` class.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "bb981983",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:29:06.571309Z",
     "iopub.status.busy": "2023-08-18T19:29:06.570433Z",
     "iopub.status.idle": "2023-08-18T19:29:06.577664Z",
     "shell.execute_reply": "2023-08-18T19:29:06.576626Z"
    },
    "origin_pos": 8,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "class Encoder(nn.Module):  #@save\n",
    "    \"\"\"The base encoder interface for the encoder--decoder architecture.\"\"\"\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "    # Later there can be additional arguments (e.g., length excluding padding)\n",
    "    def forward(self, X, *args):\n",
    "        raise NotImplementedError"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "01ce15dd",
   "metadata": {
    "origin_pos": 11
   },
   "source": [
    "## [**Decoder**]\n",
    "\n",
    "In the following decoder interface,\n",
    "we add an additional `init_state` method\n",
    "to convert the encoder output (`enc_all_outputs`)\n",
    "into the encoded state.\n",
    "Note that this step\n",
    "may require extra inputs,\n",
    "such as the valid length of the input,\n",
    "which was explained\n",
    "in :numref:`sec_machine_translation`.\n",
    "To generate a variable-length sequence token by token,\n",
    "at each time step the decoder maps an input\n",
    "(e.g., the token generated at the previous time step)\n",
    "and the encoded state\n",
    "into an output token at the current time step.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "7874f0f6",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:29:06.581746Z",
     "iopub.status.busy": "2023-08-18T19:29:06.581040Z",
     "iopub.status.idle": "2023-08-18T19:29:06.587735Z",
     "shell.execute_reply": "2023-08-18T19:29:06.586641Z"
    },
    "origin_pos": 13,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "class Decoder(nn.Module):  #@save\n",
    "    \"\"\"The base decoder interface for the encoder--decoder architecture.\"\"\"\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "    # Later there can be additional arguments (e.g., length excluding padding)\n",
    "    def init_state(self, enc_all_outputs, *args):\n",
    "        raise NotImplementedError\n",
    "\n",
    "    def forward(self, X, state):\n",
    "        raise NotImplementedError"
   ]
  },
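  {
   "cell_type": "markdown",
   "id": "5c2e9a41",
   "metadata": {},
   "source": [
    "To make the token-by-token generation concrete,\n",
    "the following is a minimal sketch of a greedy decoding loop.\n",
    "It is not part of `d2l`:\n",
    "the class `TinyDecoder`, the function `greedy_decode`,\n",
    "the token indices `bos` and `eos`, and all sizes\n",
    "are hypothetical names and values chosen for illustration,\n",
    "and a zero tensor stands in for the encoded state.\n",
    "The cell is self-contained, so `TinyDecoder` inherits from `nn.Module` directly;\n",
    "a real model would subclass `Decoder`.\n",
    "Since the weights are untrained, the generated tokens are arbitrary.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b7d3f60",
   "metadata": {
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "class TinyDecoder(nn.Module):\n",
    "    \"\"\"Hypothetical stand-in for a concrete decoder (not part of d2l).\"\"\"\n",
    "    def __init__(self, vocab_size, num_hiddens):\n",
    "        super().__init__()\n",
    "        self.embedding = nn.Embedding(vocab_size, num_hiddens)\n",
    "        self.rnn = nn.GRU(num_hiddens, num_hiddens, batch_first=True)\n",
    "        self.dense = nn.Linear(num_hiddens, vocab_size)\n",
    "\n",
    "    def forward(self, X, state):\n",
    "        # X: (batch_size, num_steps) token indices\n",
    "        outputs, state = self.rnn(self.embedding(X), state)\n",
    "        # Scores over the vocabulary, plus the updated state\n",
    "        return self.dense(outputs), state\n",
    "\n",
    "def greedy_decode(decoder, state, bos, eos, max_steps):\n",
    "    \"\"\"Feed each predicted token back in until eos or max_steps.\"\"\"\n",
    "    X = torch.tensor([[bos]])  # One sequence, one time step\n",
    "    tokens = []\n",
    "    with torch.no_grad():\n",
    "        for _ in range(max_steps):\n",
    "            Y, state = decoder(X, state)\n",
    "            X = Y.argmax(dim=2)  # Most likely next token, shape (1, 1)\n",
    "            if X.item() == eos:\n",
    "                break\n",
    "            tokens.append(X.item())\n",
    "    return tokens\n",
    "\n",
    "torch.manual_seed(0)\n",
    "decoder = TinyDecoder(vocab_size=12, num_hiddens=8)\n",
    "# Assumption: a zero tensor stands in for the encoded state\n",
    "state = torch.zeros(1, 1, 8)\n",
    "tokens = greedy_decode(decoder, state, bos=1, eos=2, max_steps=5)\n",
    "print(tokens)"
   ]
  },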
  {
   "cell_type": "markdown",
   "id": "ddb7e5a2",
   "metadata": {
    "origin_pos": 16
   },
   "source": [
    "## [**Putting the Encoder and Decoder Together**]\n",
    "\n",
    "In the forward propagation,\n",
    "the output of the encoder\n",
    "is used to produce the encoded state,\n",
    "and this state will be further used\n",
    "by the decoder as one of its inputs.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "a4223135",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-08-18T19:29:06.591527Z",
     "iopub.status.busy": "2023-08-18T19:29:06.590891Z",
     "iopub.status.idle": "2023-08-18T19:29:06.597930Z",
     "shell.execute_reply": "2023-08-18T19:29:06.596992Z"
    },
    "origin_pos": 17,
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "class EncoderDecoder(d2l.Classifier):  #@save\n",
    "    \"\"\"The base class for the encoder--decoder architecture.\"\"\"\n",
    "    def __init__(self, encoder, decoder):\n",
    "        super().__init__()\n",
    "        self.encoder = encoder\n",
    "        self.decoder = decoder\n",
    "\n",
    "    def forward(self, enc_X, dec_X, *args):\n",
    "        enc_all_outputs = self.encoder(enc_X, *args)\n",
    "        dec_state = self.decoder.init_state(enc_all_outputs, *args)\n",
    "        # Return decoder output only\n",
    "        return self.decoder(dec_X, dec_state)[0]"
   ]
  },
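  {
   "cell_type": "markdown",
   "id": "3fa81d27",
   "metadata": {},
   "source": [
    "As a rough illustration of how these interfaces might be filled in,\n",
    "the sketch below defines a hypothetical `ToyEncoder` and `ToyDecoder`\n",
    "built on a single-layer GRU and chains them\n",
    "in the same order as the `forward` method above.\n",
    "These classes, the vocabulary sizes, and the tensor shapes\n",
    "are assumptions made for this example only.\n",
    "To keep the cell self-contained, both classes inherit from `nn.Module` directly;\n",
    "in practice they would subclass `Encoder` and `Decoder`.\n",
    "Note that the source and target sequences have different lengths.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c46b0e95",
   "metadata": {
    "tab": [
     "pytorch"
    ]
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "class ToyEncoder(nn.Module):\n",
    "    \"\"\"Hypothetical encoder: embeds tokens and runs a GRU over them.\"\"\"\n",
    "    def __init__(self, vocab_size, num_hiddens):\n",
    "        super().__init__()\n",
    "        self.embedding = nn.Embedding(vocab_size, num_hiddens)\n",
    "        self.rnn = nn.GRU(num_hiddens, num_hiddens, batch_first=True)\n",
    "\n",
    "    def forward(self, X, *args):\n",
    "        # X: (batch_size, num_steps) token indices\n",
    "        outputs, state = self.rnn(self.embedding(X))\n",
    "        return outputs, state\n",
    "\n",
    "class ToyDecoder(nn.Module):\n",
    "    \"\"\"Hypothetical decoder: a GRU followed by a vocabulary projection.\"\"\"\n",
    "    def __init__(self, vocab_size, num_hiddens):\n",
    "        super().__init__()\n",
    "        self.embedding = nn.Embedding(vocab_size, num_hiddens)\n",
    "        self.rnn = nn.GRU(num_hiddens, num_hiddens, batch_first=True)\n",
    "        self.dense = nn.Linear(num_hiddens, vocab_size)\n",
    "\n",
    "    def init_state(self, enc_all_outputs, *args):\n",
    "        # Keep only the final hidden state of the encoder\n",
    "        return enc_all_outputs[1]\n",
    "\n",
    "    def forward(self, X, state):\n",
    "        outputs, state = self.rnn(self.embedding(X), state)\n",
    "        return self.dense(outputs), state\n",
    "\n",
    "encoder = ToyEncoder(vocab_size=10, num_hiddens=8)\n",
    "decoder = ToyDecoder(vocab_size=12, num_hiddens=8)\n",
    "enc_X = torch.randint(0, 10, (2, 7))  # Batch of 2, source length 7\n",
    "dec_X = torch.randint(0, 12, (2, 5))  # Batch of 2, target length 5\n",
    "\n",
    "# Same three steps as in the forward method of the base class\n",
    "enc_all_outputs = encoder(enc_X)\n",
    "dec_state = decoder.init_state(enc_all_outputs)\n",
    "Y = decoder(dec_X, dec_state)[0]\n",
    "print(Y.shape)  # torch.Size([2, 5, 12])"
   ]
  },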
  {
   "cell_type": "markdown",
   "id": "b8f923bd",
   "metadata": {
    "origin_pos": 20
   },
   "source": [
    "In the next section,\n",
    "we will see how to apply RNNs to design\n",
    "sequence-to-sequence models based on\n",
    "this encoder--decoder architecture.\n",
    "\n",
    "\n",
    "## Summary\n",
    "\n",
    "Encoder--decoder architectures\n",
    "can handle inputs and outputs\n",
    "that both consist of variable-length sequences\n",
    "and thus are suitable for sequence-to-sequence problems\n",
    "such as machine translation.\n",
    "The encoder takes a variable-length sequence as input\n",
    "and transforms it into a state with a fixed shape.\n",
    "The decoder maps the encoded state of a fixed shape\n",
    "to a variable-length sequence.\n",
    "\n",
    "\n",
    "## Exercises\n",
    "\n",
    "1. Suppose that we use neural networks to implement the encoder--decoder architecture. Do the encoder and the decoder have to be the same type of neural network?\n",
    "1. Besides machine translation, can you think of another application where the encoder--decoder architecture can be applied?\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ae1dc5a6",
   "metadata": {
    "origin_pos": 22,
    "tab": [
     "pytorch"
    ]
   },
   "source": [
    "[Discussions](https://discuss.d2l.ai/t/1061)\n"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  },
  "required_libs": []
 },
 "nbformat": 4,
 "nbformat_minor": 5
}